{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "UUXnh11hA75x",
        "jbcW5jMLK8gK"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2024 The IREE Authors"
      ],
      "metadata": {
        "id": "UUXnh11hA75x"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
        "# See https://llvm.org/LICENSE.txt for license information.\n",
        "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
      ],
      "metadata": {
        "cellView": "form",
        "id": "FqsvmKpjBJO2"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# <img src=\"https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png\" height=\"20px\"> Hugging Face to <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch to <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n",
        "\n",
        "This notebook uses [iree-turbine](https://github.com/iree-org/iree-turbine) to export a pretrained [Hugging Face Transformers](https://huggingface.co/docs/transformers/) model to [IREE](https://github.com/iree-org/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.\n",
        "\n",
        "* The pretrained [whisper-small](https://huggingface.co/openai/whisper-small)\n",
        "  model is showcased here as it is small enough to fit comfortably into a Colab\n",
        "  notebook. Other pretrained models can be found at\n",
        "  https://huggingface.co/docs/transformers/index."
      ],
      "metadata": {
        "id": "38UDc27KBPD1"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Setup"
      ],
      "metadata": {
        "id": "jbcW5jMLK8gK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "#@title Uninstall existing packages\n",
        "#   This avoids some warnings when installing specific PyTorch packages below.\n",
        "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision"
      ],
      "metadata": {
        "id": "KsPubQSvCbXd"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!python -m pip install --pre --index-url https://download.pytorch.org/whl/cpu --upgrade torch==2.5.0"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oO1tirq2ggmO",
        "outputId": "1c10e964-1bd3-41e7-d7ce-70cf574d817b"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Looking in indexes: https://download.pytorch.org/whl/cpu\n",
            "Collecting torch==2.5.0\n",
            "  Downloading https://download.pytorch.org/whl/cpu/torch-2.5.0%2Bcpu-cp310-cp310-linux_x86_64.whl (174.7 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m174.7/174.7 MB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.16.1)\n",
            "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (4.12.2)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.1.5)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (2024.10.0)\n",
            "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (1.13.1)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch==2.5.0) (1.3.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.5.0) (3.0.2)\n",
            "Installing collected packages: torch\n",
            "  Attempting uninstall: torch\n",
            "    Found existing installation: torch 2.5.1+cu121\n",
            "    Uninstalling torch-2.5.1+cu121:\n",
            "      Successfully uninstalled torch-2.5.1+cu121\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "timm 1.0.12 requires torchvision, which is not installed.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed torch-2.5.0+cpu\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4iJFDHbsAzo4",
        "outputId": "c95e32a5-70ab-43e7-8c8c-300d37cccfd3"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting iree-turbine\n",
            "  Downloading iree_turbine-3.1.0-py3-none-any.whl.metadata (6.7 kB)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.26.4)\n",
            "Collecting iree-base-compiler (from iree-turbine)\n",
            "  Downloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n",
            "Collecting iree-base-runtime (from iree-turbine)\n",
            "  Downloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n",
            "Requirement already satisfied: Jinja2>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (3.1.5)\n",
            "Collecting ml_dtypes>=0.5.0 (from iree-turbine)\n",
            "  Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\n",
            "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (4.12.2)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.1.3->iree-turbine) (3.0.2)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from iree-base-compiler->iree-turbine) (1.13.1)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->iree-base-compiler->iree-turbine) (1.3.0)\n",
            "Downloading iree_turbine-3.1.0-py3-none-any.whl (301 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m301.7/301.7 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m34.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (71.2 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.2/71.2 MB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (8.2 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.2/8.2 MB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: ml_dtypes, iree-base-runtime, iree-base-compiler, iree-turbine\n",
            "  Attempting uninstall: ml_dtypes\n",
            "    Found existing installation: ml-dtypes 0.4.1\n",
            "    Uninstalling ml-dtypes-0.4.1:\n",
            "      Successfully uninstalled ml-dtypes-0.4.1\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "tensorflow 2.17.1 requires ml-dtypes<0.5.0,>=0.3.1, but you have ml-dtypes 0.5.1 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed iree-base-compiler-3.1.0 iree-base-runtime-3.1.0 iree-turbine-3.1.0 ml_dtypes-0.5.1\n"
          ]
        }
      ],
      "source": [
        "!python -m pip install iree-turbine"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Report version information\n",
        "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n",
        "\n",
        "!echo -e \"\\nInstalled IREE, compiler version information:\"\n",
        "!iree-compile --version\n",
        "\n",
        "import torch\n",
        "print(\"\\nInstalled PyTorch, version:\", torch.__version__)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nkVLzRpcDnVL",
        "outputId": "210a54b9-4044-4426-f9ee-09d5fd23839c"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Installed iree-turbine, Version: 3.1.0\n",
            "\n",
            "Installed IREE, compiler version information:\n",
            "IREE (https://iree.dev):\n",
            "  IREE compiler version 3.1.0rc20250107 @ d2242207764230ad398585a5771f9d54ce91b4c8\n",
            "  LLVM version 20.0.0git\n",
            "  Optimized build\n",
            "\n",
            "Installed PyTorch, version: 2.5.0+cpu\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Load and run whisper-small"
      ],
      "metadata": {
        "id": "I0OfTFxwOud1"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Load the pretrained model from https://huggingface.co/openai/whisper-small.\n",
        "\n",
        "See also:\n",
        "\n",
        "* Model card: https://huggingface.co/docs/transformers/model_doc/whisper\n",
        "* Test case in [SHARK-TestSuite](https://github.com/nod-ai/SHARK-TestSuite/): [`pytorch/models/whisper-small/model.py`](https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/pytorch/models/whisper-small/model.py)"
      ],
      "metadata": {
        "id": "94Ji4URLT_xM"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "\n",
        "# https://huggingface.co/docs/transformers/model_doc/auto\n",
        "# AutoModelForCausalLM -> WhisperForCausalLM\n",
        "# AutoTokenizer        -> WhisperTokenizerFast\n",
        "\n",
        "modelname = \"openai/whisper-small\"\n",
        "tokenizer = AutoTokenizer.from_pretrained(modelname)\n",
        "\n",
        "# Some of the options here affect how the model is exported. See the test cases\n",
        "# at https://github.com/nod-ai/SHARK-TestSuite/tree/main/e2eshark/pytorch/models\n",
        "# for other options that may be useful to set.\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    modelname,\n",
        "    output_attentions=False,\n",
        "    output_hidden_states=False,\n",
        "    attn_implementation=\"eager\",\n",
        "    torchscript=True,\n",
        ")\n",
        "\n",
        "# This is just a simple demo to get some data flowing through the model.\n",
        "# Depending on this model and what input it expects (text, image, audio, etc.)\n",
        "# this might instead use a specific Processor class. For Whisper,\n",
        "# WhisperProcessor runs audio input pre-processing and output post-processing.\n",
        "example_prompt = \"Hello world!\"\n",
        "example_encoding = tokenizer(example_prompt, return_tensors=\"pt\")\n",
        "example_input = example_encoding[\"input_ids\"].cpu()\n",
        "example_args = (example_input,)"
      ],
      "metadata": {
        "id": "HLbfUuoBPHgH",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 476,
          "referenced_widgets": [
            "aaaaa863923e4b6fb92380ff063677a7",
            "37ee0e0bcecc44a1aaf25a8e66d752ad",
            "139c5c777da84307a28ffc7de9e037f6",
            "ea6b64fa6c5a4d0aa8797ac9dfa9051a",
            "28125f8e13a14bfaad7c645590360b22",
            "4b4afd6e36a94799856aa8dd2243f9a2",
            "7e2f1ea6da7d430fa5df1fc3b9733e10",
            "5257eab4b9cd42bbb106404bfe428903",
            "9ae9e376cec54470b398a9cf5fa7b9fd",
            "b0a36845fd614468a45de443b2e5c0f8",
            "757a8987af504be89d2cc7d058728d9c",
            "d88e2b3299f8431bbe7af6e1232ce54c",
            "d1d3a7166d0f40b4bb26a146e8b76d5a",
            "f105605987244b56b6e33f162fbe6930",
            "70a4da0f7f224aa9b828c4493fe31101",
            "4e8a310daae6485cb53c853bcbb6b029",
            "9c966c41c6eb4407b6c8751c29b9d082",
            "86685457be39483fa223541a8e51a79e",
            "cc26cd911fb84bc19bdb782060138df4",
            "a8fc625551bb40c9aedb37d548837cd2",
            "ef484e2a7891478d822929c4728dfdd3",
            "c2a512da304643e9ae86eb6b1c434934",
            "d24fd8bd3c6349b8a265a18f96901458",
            "a217516b31244cb194bb47f4da51ae6c",
            "8b50957482094ec58561ed62fe53c720",
            "87e3626416c440a785c3898baf2c8bce",
            "ecafe4f7ddc64973a6fce7a4d0fcdcb3",
            "cfa6bedf2057488ca273dea84107cdc9",
            "6e7a805a79c749f48d3bfa1028e3de70",
            "fff09ea7e19b47f29304f9f315425884",
            "9931ae28665347c883d6b4723f405bfb",
            "f454754510404b61b18a8f87cd8ba1ce",
            "da14fb80f6994378bf405b38c3f86bab",
            "560e52debb244737b1cb8f3088506e80",
            "4fe7fe3078b24fc3811072b7790a9371",
            "bec9bf346af9464b8fb120fbbaf2fefd",
            "ffd5fca2e0b84317a473e358e85f3d77",
            "114aa03e366f46188a771c98607e5adb",
            "e8dbea7f1cd0443ca3cbe114e24ff3da",
            "ab512519d4bb4d28b6109a25d9bd6b88",
            "4ead0d0ec9994682ab8cc9ef027eaae2",
            "3e61b4815dd14f619b90c715d62df347",
            "c1420dce3e7246f3866a8ce85474fe5f",
            "fa9daa6d53564dfdb7af8593edc69884",
            "a6b1386c9842438ca5801134d26f0b51",
            "c21ed7621a004d9bb24e9d68c3a76a6f",
            "7b6e231cec8946118d7bcd745f518010",
            "2fd77a6a408a41f1be6068ef77057a28",
            "2c2b92f12df041b0b973ae9793cdc1ab",
            "446e30992e244f48a68d9130e82a7126",
            "9aece5f7db4646a7a7f46a186cde18eb",
            "6faf78708ec44e5f8d5db701938965c3",
            "3a26fee124d442158961fd5c1b28e5bd",
            "549fc99a11d44ef1ab48c64b742d58e7",
            "7516468965394c4b8bcbc7ea3db3b457",
            "d476cb392ec5423cba72f46757f1df1a",
            "9fbede93702942c492488056485bdb6c",
            "e5b32ffa050946469cc375b9234e35ae",
            "4d18beace81f48c08540812193fd5244",
            "9076e8da1009489291af97753dcae650",
            "5f0a6e2ed40a47cc8016e1a11100579e",
            "df19cd089b854db6ad2230f0a457aca8",
            "b671614693b141f9821ef7b78ee98ae1",
            "f3e226744d6e45918d4a48b924389ad7",
            "6f4e119ee815411094a7d9d5311f10b5",
            "77503681f26b46419157af3d49a71bcb",
            "16778bf77b8942939e323b655ac4dfa6",
            "67057ee8ce464d64a7769ade5d7479dd",
            "e05d7305000c4b1091631d7b15f7900a",
            "c9c0abecdc0e4462a8a72e47bfb1e53d",
            "02202837100e448383e6615758556655",
            "86ffbd1afeab411fbd322c909cb21a5d",
            "5870517d9c124931886a88a310c2386e",
            "59dc9a70d31c492788508f67cd975365",
            "cf5a0cc0a32d4df2a871ec642d2de5da",
            "9a9554c7d1f04d9ab14df042cfefbedd",
            "33a04a7c30dd42de9cfe6a43e988604d",
            "83e8699174dd43bc9b646d6d49c993b3",
            "27dbc17542ba451888f1e847b59f2da7",
            "5a480878ee0d441da7ab360d1c93fbc6",
            "d646f8677a1b4ca190f2410a8cfc05b7",
            "1c4cd997e1104a4ba0f3a15b4ce7dea0",
            "342b91aaba4d4df8aa567779d1a6f4e7",
            "c7487e228a7048ff9e9be026dc5a9f46",
            "60cec20206c243efa658a94c7278719c",
            "57f8627eb25c4b84a74c4bf71c6c122f",
            "3772acd3cc9d4e61ab2a4bf0f4a15774",
            "f7a3e9b12c4b461e972113b8880aa985",
            "a8fdc4ca612b47c68fd130b51bcd1ece",
            "b5a84ce7032343b9b040a03e5432d96f",
            "b893ddddacf74e7c8ca40666fb84c24e",
            "9b44401b93664666a42c56d3165de181",
            "c976d288588145069099727ab5183da6",
            "b952b8ec0ebb4727802daddb7a3d5f4a",
            "18293fa01eb14c8b9d7f2b00669ba82a",
            "eab894861d1f406c973b525366a8e157",
            "b727f95de2e245019a61ac5c9466ece4",
            "a23bf4d95ebe4af1991fe1531b6b7b2f",
            "fe045669fdc147a5a6bde04b4f31fcef",
            "ccee9eaca1d944099ee97f8c0a5790b5",
            "675803d6f0b94b07b3b465b427547a00",
            "cff95693504d4e7eac5fdfe972cf7e12",
            "566899e9d7c04a7c954f7186d988c1e4",
            "d994cd6184ee402ca6e4ae0b7db8faea",
            "d9d7829fe66049379ecf88ce7b385c34",
            "fb82acd9ae0f441d9c1dcf87b47f3486",
            "ef1916aba2f449cbabc64248ef9cc95f",
            "58d678849b444dfaa467f87a1b7bd9fc",
            "bb1027ed5dbb4fb68f90480e85cde62c",
            "79cf395270734f4a926dd3fc165f65e2"
          ]
        },
        "outputId": "c33917e5-8ec5-4e03-85c3-9424a529fac9"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "aaaaa863923e4b6fb92380ff063677a7"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "vocab.json:   0%|          | 0.00/836k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d88e2b3299f8431bbe7af6e1232ce54c"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d24fd8bd3c6349b8a265a18f96901458"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "560e52debb244737b1cb8f3088506e80"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "a6b1386c9842438ca5801134d26f0b51"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d476cb392ec5423cba72f46757f1df1a"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "16778bf77b8942939e323b655ac4dfa6"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "config.json:   0%|          | 0.00/1.97k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "83e8699174dd43bc9b646d6d49c993b3"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/967M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "a8fdc4ca612b47c68fd130b51bcd1ece"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Some weights of WhisperForCausalLM were not initialized from the model checkpoint at openai/whisper-small and are newly initialized: ['proj_out.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/3.87k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ccee9eaca1d944099ee97f8c0a5790b5"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Test exporting using [`torch.export()`](https://pytorch.org/docs/stable/export.html#torch.export.export). If `torch.export` works, `aot.export()` from Turbine should work as well."
      ],
      "metadata": {
        "id": "vQlF_ua3UNvo"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "exported_program = torch.export.export(model, example_args)"
      ],
      "metadata": {
        "id": "-4LykgffY9uH"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Export using the simple [`aot.export()`](https://iree.dev/guides/ml-frameworks/pytorch/#simple-api) API from Turbine."
      ],
      "metadata": {
        "id": "wXZI4GliUazA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import iree.turbine.aot as aot\n",
        "# Note: aot.export() wants the example args to be unpacked.\n",
        "whisper_compiled_module = aot.export(model, *example_args)"
      ],
      "metadata": {
        "id": "R7-rN_z2Y_5z"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Compile using Turbine/IREE then run the program."
      ],
      "metadata": {
        "id": "YK3hjpTpUdhc"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "binary = whisper_compiled_module.compile(save_to=None)\n",
        "\n",
        "import iree.runtime as ireert\n",
        "config = ireert.Config(\"local-task\")\n",
        "vm_module = ireert.load_vm_module(\n",
        "    ireert.VmModule.wrap_buffer(config.vm_instance, binary.map_memory()),\n",
        "    config,\n",
        ")\n",
        "\n",
        "iree_outputs = vm_module.main(example_args[0])\n",
        "print(iree_outputs[0].to_host())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FctBxxEXZBan",
        "outputId": "89733fb1-3d5a-4258-9da6-394d56ccd230"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[[  5.8126216   3.9667568   4.5749426 ...   2.7658575   2.6436937\n",
            "     1.5479789]\n",
            "  [  7.5634375   6.029962    5.1000347 ...   6.432704    6.101554\n",
            "     6.4348   ]\n",
            "  [  0.9380306  -4.4696145  -4.012748  ...  -6.2486286  -7.7917867\n",
            "    -6.8453736]\n",
            "  [  0.7450936  -3.7631674  -7.4870253 ...  -6.734828   -6.966235\n",
            "   -10.022404 ]\n",
            "  [ -0.9628601  -3.510199   -6.015854  ...  -7.116391   -6.7086434\n",
            "   -10.225704 ]\n",
            "  [  3.347097    2.4927166  -3.3042672 ...  -1.5709717  -1.8455461\n",
            "    -2.9991992]]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Run the program using native PyTorch to compare outputs."
      ],
      "metadata": {
        "id": "5WuFpyFfUjh8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "torch_outputs = model(example_args[0])\n",
        "print(torch_outputs[0].detach().numpy())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IxPYkcPycG4r",
        "outputId": "f21bc1a0-ddc3-49ec-a122-927ad4e4b54b"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[[  5.8126183    3.9667587    4.5749483  ...   2.7658575    2.643694\n",
            "     1.5479784 ]\n",
            "  [  7.563436     6.029952     5.100036   ...   6.4327083    6.101557\n",
            "     6.4348083 ]\n",
            "  [  0.93802685  -4.469646    -4.012787   ...  -6.2486415   -7.7918167\n",
            "    -6.8453975 ]\n",
            "  [  0.74507916  -3.763197    -7.487034   ...  -6.734877    -6.966276\n",
            "   -10.022424  ]\n",
            "  [ -0.96288276  -3.510221    -6.0158725  ...  -7.1164136   -6.708687\n",
            "   -10.225745  ]\n",
            "  [  3.3470666    2.492654    -3.304323   ...  -1.5709934   -1.8455791\n",
            "    -2.9992423 ]]]\n"
          ]
        }
      ]
    }
  ]
}